import os 
import abc
import math
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import timm.models.vision_transformer

from torch import einsum
from ldm.util import instantiate_from_config
from ldm.modules.diffusionmodules.util import extract_into_tensor
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from inspect import isfunction
from functools import partial
from omegaconf import OmegaConf
from models_moe import PatchEmbed, MoEnhanceBlock, MoEnhanceTaskBlock
from ldm.modules.attention import MoE_Transformer
from tqdm import tqdm
from timm.models.layers import DropPath, trunc_normal_
from timm.models.vision_transformer import Mlp
# from detectron2.layers import CNNBlockBase, Conv2d, get_norm
import fvcore.nn.weight_init as weight_init

class SimpleDecoding(nn.Module):
    """Top-down decoder that merges four feature maps into a single output map.

    Starting from the deepest map, each stage resizes the running feature to the
    next skip connection (only when it is spatially smaller), concatenates the
    two along the channel axis and refines the result with two
    conv3x3 -> BatchNorm -> ReLU units. A final bias-free 1x1 convolution
    projects to ``out_channels``.

    Args:
        out_channels (int): Number of channels of the returned map.
        dims (Sequence[int]): Channel widths of the inputs; the last four
            entries are read as (c1, c2, c3, c4).
        factor (int): The hidden width is ``dims[-1] // factor``.
    """

    def __init__(self, out_channels, dims, factor=2):
        super().__init__()

        c1_size, c2_size, c3_size, c4_size = dims[-4], dims[-3], dims[-2], dims[-1]
        hidden_size = c4_size // factor

        def conv3x3(in_channels):
            # Every refinement conv maps to the hidden width and has no bias
            # because a BatchNorm follows it.
            return nn.Conv2d(in_channels, hidden_size, 3, padding=1, bias=False)

        # Stage fusing c4 with c3.
        self.conv1_4 = conv3x3(c4_size + c3_size)
        self.bn1_4 = nn.BatchNorm2d(hidden_size)
        self.relu1_4 = nn.ReLU()
        self.conv2_4 = conv3x3(hidden_size)
        self.bn2_4 = nn.BatchNorm2d(hidden_size)
        self.relu2_4 = nn.ReLU()

        # Stage fusing the running feature with c2.
        self.conv1_3 = conv3x3(hidden_size + c2_size)
        self.bn1_3 = nn.BatchNorm2d(hidden_size)
        self.relu1_3 = nn.ReLU()
        self.conv2_3 = conv3x3(hidden_size)
        self.bn2_3 = nn.BatchNorm2d(hidden_size)
        self.relu2_3 = nn.ReLU()

        # Stage fusing the running feature with c1.
        self.conv1_2 = conv3x3(hidden_size + c1_size)
        self.bn1_2 = nn.BatchNorm2d(hidden_size)
        self.relu1_2 = nn.ReLU()
        self.conv2_2 = conv3x3(hidden_size)
        self.bn2_2 = nn.BatchNorm2d(hidden_size)
        self.relu2_2 = nn.ReLU()

        # Output projection.
        self.conv1_1 = nn.Conv2d(hidden_size, out_channels, 1, bias=False)

    @staticmethod
    def _upsample_to(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size if ``x`` is smaller on either axis."""
        target_h, target_w = ref.size(-2), ref.size(-1)
        if x.size(-2) < target_h or x.size(-1) < target_w:
            x = F.interpolate(input=x, size=(target_h, target_w), mode='bilinear', align_corners=True)
        return x

    def forward(self, x_c4, x_c3, x_c2, x_c1):
        """Fuse the four maps from deepest (``x_c4``) to shallowest (``x_c1``).

        Returns:
            Tensor with ``out_channels`` channels at the spatial size reached
            after the last fusion stage.
        """
        stages = (
            (x_c3, (self.conv1_4, self.bn1_4, self.relu1_4, self.conv2_4, self.bn2_4, self.relu2_4)),
            (x_c2, (self.conv1_3, self.bn1_3, self.relu1_3, self.conv2_3, self.bn2_3, self.relu2_3)),
            (x_c1, (self.conv1_2, self.bn1_2, self.relu1_2, self.conv2_2, self.bn2_2, self.relu2_2)),
        )

        x = x_c4
        for skip, layers in stages:
            x = torch.cat([self._upsample_to(x, skip), skip], dim=1)
            for layer in layers:
                x = layer(x)

        return self.conv1_1(x)


def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    return not (val is None)


def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    A fallback that is a plain Python function (as judged by
    ``inspect.isfunction``) is called with no arguments and its result is
    returned; any other fallback object is returned unchanged.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d


def register_attention_control(model, controller):
    """Replace the forward of every ``CrossAttention`` module so it reports its attention maps.

    The children of ``model.diffusion_model`` whose names contain
    ``input_blocks``, ``output_blocks`` or ``middle_block`` are walked
    recursively. Each module whose class is named ``CrossAttention`` gets an
    instance-level ``forward`` that recomputes attention, averages the
    probabilities over heads and passes them to ``controller`` together with a
    cross/self flag and the location tag ("down", "up" or "mid").

    Args:
        model: Object exposing a ``diffusion_model`` attribute with ``named_children()``.
        controller: Callable ``(attn, is_cross, place_in_unet)``; when ``None``
            a no-op controller is used. Its ``num_att_layers`` attribute is set
            to the number of patched modules.
    """

    class DummyController:
        """Stand-in used when no controller is supplied; it simply echoes the attention."""

        def __init__(self):
            self.num_att_layers = 0

        def __call__(self, *args):
            return args[0]

    if controller is None:
        controller = DummyController()

    def ca_forward(self, place_in_unet):
        # `self` is the attention module being patched, captured by the closure.
        def forward(x, context=None, mask=None):
            h = self.heads

            # Without a context the layer attends to itself (self-attention).
            is_cross = context is not None
            q = self.to_q(x)
            context = default(context, x)
            k = self.to_k(context)
            v = self.to_v(context)

            # Fold the head axis into the batch axis.
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

            if exists(mask):
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                # Masked-out keys get the most negative finite value so softmax ignores them.
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)

            # Hand the head-averaged attention to the controller; its return value is not used.
            head_mean = rearrange(attn, '(b h) k c -> h b k c', h=h).mean(0)
            controller(head_mean, is_cross, place_in_unet)

            out = einsum('b i j, b j d -> b i d', attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
            return self.to_out(out)

        return forward

    def register_recr(net_, count, place_in_unet):
        """Patch ``net_`` (or its descendants) and return the updated patch count."""
        if net_.__class__.__name__ == 'CrossAttention':
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    # Substring of the child name -> location tag; the first match wins.
    location_by_marker = (
        ("input_blocks", "down"),
        ("output_blocks", "up"),
        ("middle_block", "mid"),
    )

    cross_att_count = 0
    for child_name, child in model.diffusion_model.named_children():
        for marker, location in location_by_marker:
            if marker in child_name:
                cross_att_count += register_recr(child, 0, location)
                break

    controller.num_att_layers = cross_att_count


class AttentionControl(abc.ABC):
    """Abstract base for objects that receive attention maps from patched layers.

    Subclasses implement ``forward``; calling the instance forwards the
    arguments to it and returns its result.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # -1 until a registration routine assigns the real count
        self.cur_att_layer = 0

    def reset(self):
        """Reset the step and layer counters to zero."""
        self.cur_step, self.cur_att_layer = 0, 0

    @property
    def num_uncond_att_layers(self):
        return 0

    def step_callback(self, x_t):
        """Hook invoked with ``x_t``; the base implementation returns it untouched."""
        return x_t

    def between_steps(self):
        """Hook with no effect in the base class."""
        return

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        return self.forward(attn, is_cross, place_in_unet)


class AttentionStore(AttentionControl):
    """Controller that collects attention maps into per-location lists.

    Maps are filed under ``"<place>_<cross|self>"`` keys in ``step_store``.
    Only maps whose second dimension is at most ``max_size ** 2`` are kept.

    Args:
        base_size (int): Stored as ``self.base_size``; also used to derive the
            default ``max_size``.
        max_size (int or None): Size limit described above; defaults to
            ``base_size // 2`` when ``None``.
    """

    def __init__(self, base_size=64, max_size=None):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.base_size = base_size
        self.max_size = self.base_size // 2 if max_size is None else max_size

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with an empty list for each location/kind key."""
        return {
            f"{place}_{kind}": []
            for kind in ("cross", "self")
            for place in ("down", "mid", "up")
        }

    def reset(self):
        """Reset the base counters and drop everything collected so far."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` if it is small enough and return it unchanged."""
        if attn.shape[1] <= (self.max_size) ** 2:  # avoid memory overhead
            kind = 'cross' if is_cross else 'self'
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        """Fold ``step_store`` into ``attention_store`` and start a new step store."""
        if not self.attention_store:
            self.attention_store = self.step_store
        else:
            for key, accumulated in self.attention_store.items():
                for i in range(len(accumulated)):
                    accumulated[i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return a shallow copy of the current ``step_store`` lists.

        Despite the name, no averaging is done here; the lists of the current
        step store are copied as-is.
        """
        return {key: list(items) for key, items in self.step_store.items()}


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding via a single strided convolution.

    Note: this definition shadows the ``PatchEmbed`` name imported from
    ``models_moe`` at the top of the file.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Project ``x`` (B, C, H, W) and return the result channels-last as (B, H', W', embed_dim)."""
        return self.proj(x).permute(0, 2, 3, 1)


def get_rel_pos(q_size, k_size, rel_pos):
    """
    Look up relative positional embeddings for every (query, key) position pair.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Tensor of shape (q_size, k_size, C) holding the embedding selected for
        each query/key offset.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    table = rel_pos
    if table.shape[0] != max_rel_dist:
        # The table length does not match the required number of offsets:
        # linearly resample it along the position axis.
        channels_first = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        table = F.interpolate(channels_first, size=max_rel_dist, mode="linear")
        table = table.reshape(-1, max_rel_dist).permute(1, 0)

    # When q and k differ in size, the shorter axis is stretched so both span
    # the same coordinate range.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    # Shift so that the smallest offset maps to index 0.
    shift = (k_size - 1) * k_scale
    relative_coords = (q_coords - k_coords) + shift

    return table[relative_coords.long()]


def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Add decomposed (height + width) relative positional terms to an attention map,
    following :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950
    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings,
            shaped (B, q_h * q_w, k_h * k_w).
    """
    (q_h, q_w), (k_h, k_w) = q_size, k_size
    height_table = get_rel_pos(q_h, k_h, rel_pos_h)
    width_table = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    q_grid = q.reshape(B, q_h, q_w, dim)
    # Query/embedding dot products, separately along each spatial axis.
    rel_h = torch.einsum("bhwc,hkc->bhwk", q_grid, height_table)
    rel_w = torch.einsum("bhwc,wkc->bhwk", q_grid, width_table)

    # Broadcast the height term over key columns and the width term over key rows.
    attn_grid = attn.view(B, q_h, q_w, k_h, k_w)
    attn_grid = attn_grid + rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)

    return attn_grid.view(B, q_h * q_w, k_h * k_w)


class Attention(nn.Module):
    """Multi-head self-attention over a token sequence of shape (B, T, dim).

    Relative position parameters can be created via ``use_rel_pos`` but
    ``forward`` does not apply them; it computes plain scaled dot-product
    attention.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, create the relative positional
                parameters ``rel_pos_h`` / ``rel_pos_w`` (requires ``input_size``).
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (Tuple[int, int] or None): Input resolution (H, W) used to size the
                relative positional parameters.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            # initialize relative positional embeddings
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

            if not rel_pos_zero_init:
                trunc_normal_(self.rel_pos_h, std=0.02)
                trunc_normal_(self.rel_pos_w, std=0.02)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x (Tensor): Tokens of shape (B, T, dim).

        Returns:
            Tensor of shape (B, T, dim).
        """
        # Use the actual token count instead of a hard-coded 16 * 32 + 1 so the
        # layer works for any sequence length (513 tokens behave as before).
        B, T, _ = x.shape
        heads = self.num_heads

        # (3, B, heads, T, head_dim)
        qkv = self.qkv(x).reshape(B, T, 3, heads, -1).permute(2, 0, 3, 1, 4)
        # Fold heads into the batch axis: each of q, k, v is (B * heads, T, head_dim).
        q, k, v = qkv.reshape(3, B * heads, T, -1).unbind(0)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)

        # Unfold the heads and concatenate them back along the channel axis.
        x = (attn @ v).view(B, heads, T, -1).permute(0, 2, 1, 3).reshape(B, T, -1)
        x = self.proj(x)

        return x
    

def window_partition(x, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows,
    zero-padding the bottom/right edges when the size is not a multiple of the window.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    # Amount needed to reach the next multiple of window_size (0 if already aligned).
    pad_h = -H % window_size
    pad_w = -W % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad takes pairs starting from the last dim: (C, C, W, W, H, H).
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp = H + pad_h
    Wp = W + pad_w

    rows, cols = Hp // window_size, Wp // window_size
    grid = x.view(B, rows, window_size, cols, window_size, C)
    # Bring the two window-grid axes together, then flatten them into the batch axis.
    windows = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    windows = windows.view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Reassemble windows into full feature maps and strip any padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw

    windows_per_image = Hp * Wp // window_size // window_size
    B = windows.shape[0] // windows_per_image
    rows, cols = Hp // window_size, Wp // window_size

    grid = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave grid and in-window axes back into (Hp, Wp).
    x = grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    was_padded = Hp > H or Wp > W
    return x[:, :H, :W, :].contiguous() if was_padded else x


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each wrapped in a residual connection.

    NOTE(review): the windowed path (``window_size > 0``) reads ``x.shape[1]``
    and ``x.shape[2]`` as H and W, i.e. it expects a 4-D (B, H, W, C) input,
    while the ``Attention`` layer in this file unpacks a 3-D shape. The two do
    not look compatible; confirm before enabling window attention.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        use_rel_pos=False,
        rel_pos_zero_init=True,
        window_size=0,
        use_residual_block=False,
        input_size=None,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): Forwarded to the attention layer.
            rel_pos_zero_init (bool): Forwarded to the attention layer.
            window_size (int): Window size for window attention blocks. If it equals 0, then not
                use window attention.
            use_residual_block (bool): Stored on the module; not read by ``forward``.
            input_size (int or None): Input resolution for calculating the relative positional
                parameter size.
        """
        super().__init__()

        # With window attention the attention layer is sized for one window
        # rather than for the full input.
        attn_input_size = (window_size, window_size) if window_size != 0 else input_size

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            use_rel_pos=use_rel_pos,
            rel_pos_zero_init=rel_pos_zero_init,
            input_size=attn_input_size,
        )

        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer)

        self.window_size = window_size
        self.use_residual_block = use_residual_block

    def forward(self, x):
        """Run attention then MLP, each added back onto its input."""
        windowed = self.window_size > 0

        residual = x
        x = self.norm1(x)

        if windowed:
            # Remember the unpadded size so it can be restored after attention.
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        if windowed:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = residual + self.drop_path(x)
        return x + self.drop_path(self.mlp(self.norm2(x)))


class TaskEmbeddingExtractor(nn.Module):
    """Summarise support (image, label) pairs into one task embedding per UNet width.

    The label is turned into three channels by three learnable affine
    transforms, placed next to the image along the width axis, patch-embedded
    with 16x16 patches, prefixed with a learnable task token and passed through
    three transformer blocks. The task token is averaged over the support batch
    and projected to 320, 640 and 1280 dimensions.

    The reshape in ``forward`` fixes the patch grid at 16 x 32, which with
    16x16 patches corresponds to a 256 x 512 concatenated input.
    """

    def __init__(self):
        super().__init__()

        # Stored for reference; not read anywhere in this class.
        self.H = 256
        self.W = 256

        # Learnable affine transforms (scale a_i, shift b_i) applied to the label.
        self.a1 = nn.Parameter(torch.Tensor([1.0]))
        self.b1 = nn.Parameter(torch.Tensor([0.0]))
        self.a2 = nn.Parameter(torch.Tensor([1.0]))
        self.b2 = nn.Parameter(torch.Tensor([0.0]))
        self.a3 = nn.Parameter(torch.Tensor([1.0]))
        self.b3 = nn.Parameter(torch.Tensor([0.0]))

        self.patch_embed = PatchEmbed(
            kernel_size=(16, 16),
            stride=(16, 16),
            in_chans=3,
            embed_dim=1024,
        )

        # 16 x 32 patch tokens plus one task token.
        num_positions = 16*32+1
        self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, 1024), requires_grad=True)

        self.token_task = nn.Parameter(torch.zeros(1, 1024), requires_grad=True)

        depth = 3
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(
                Block(
                    dim=1024,
                    num_heads=16,
                    input_size=(16, 32),
                )
            )

        self.linear_320 = nn.Linear(1024, 320)
        self.linear_640 = nn.Linear(1024, 640)
        self.linear_1280 = nn.Linear(1024, 1280)

    def forward(self, image_support, label_support):
        """Compute the task embeddings.

        Args:
            image_support (Tensor): Support images, indexed as (bs, 3, H, W).
            label_support (Tensor): Support labels, indexed as (bs, C, H, W);
                the channel axis is averaged away.

        Returns:
            Tuple of three tensors with shapes (1, 320), (1, 640) and (1, 1280).
        """
        bs = image_support.shape[0]

        # Three affine views of the label, each averaged over the channel axis
        # and then stacked back into a 3-channel map.
        affine_params = ((self.a1, self.b1), (self.a2, self.b2), (self.a3, self.b3))
        label_channels = [
            (label_support * scale + shift).mean(dim=1) for scale, shift in affine_params
        ]
        label_support_transform = torch.stack(label_channels, dim=1)

        # Image and label sit side by side along the width axis.
        input_concat = torch.cat([image_support, label_support_transform], dim=3)
        input_token = self.patch_embed(input_concat)
        input_token = input_token.view(bs, 16*32, 1024)
        input_token = torch.cat([self.token_task.unsqueeze(0).repeat(bs, 1, 1), input_token], dim=1)

        x = input_token + self.pos_embed  # (bs, 513, 1024)
        for blk in self.blocks:
            x = blk(x)

        # Average the task token over the support batch. No squeeze() here: with
        # a single support sample it would also drop the batch axis, the mean
        # would then collapse the feature axis and the projections would fail.
        task_token = x[:, 0, :]  # (bs, 1024)
        task_embedding = task_token.mean(dim=0, keepdim=True)  # (1, 1024)

        task_embedding_320 = self.linear_320(task_embedding)  # (1, 320)
        task_embedding_640 = self.linear_640(task_embedding)  # (1, 640)
        task_embedding_1280 = self.linear_1280(task_embedding)  # (1, 1280)

        return (task_embedding_320, task_embedding_640, task_embedding_1280)

def register_hier_output(model):
    """Replace ``model.diffusion_model.forward`` with one that returns intermediate features.

    The new forward runs the input blocks, the middle block and the output
    blocks as usual, but instead of a single tensor it returns a list holding
    the outputs of output blocks 1, 4 and 7 followed by the final feature map
    (cast back to the dtype of ``x``). When ``id_router`` is given, it and
    ``task_embedding`` are forwarded to every block as keyword arguments.
    """
    unet = model.diffusion_model

    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding

    # Indices of the output blocks whose results are collected.
    tapped_outputs = (1, 4, 7)

    def run_block(block, h, emb, id_router, task_embedding):
        # Routing arguments are only passed when a router id was supplied.
        if id_router is None:
            return block(h, emb)
        return block(h, emb, id_router=id_router, task_embedding=task_embedding)

    def forward(x, timesteps, id_router = None, task_embedding = None):
        t_emb = timestep_embedding(timesteps, unet.model_channels, repeat_only=False)
        emb = unet.time_embed(t_emb)

        # downsampling: keep every output for the skip connections
        skips = []
        h = x.type(unet.dtype)
        for module in unet.input_blocks:
            h = run_block(module, h, emb, id_router, task_embedding)
            skips.append(h)

        # middle
        h = run_block(unet.middle_block, h, emb, id_router, task_embedding)

        # upsampling: consume the skips in reverse order
        out_list = []
        for i_out, module in enumerate(unet.output_blocks):
            h = torch.cat([h, skips.pop()], dim=1)
            h = run_block(module, h, emb, id_router, task_embedding)
            if i_out in tapped_outputs:
                out_list.append(h)

        out_list.append(h.type(x.dtype))
        return out_list

    unet.forward = forward


class UNetWrapper(nn.Module):
    """Wrap a UNet so it returns multi-scale features, optionally with attention maps appended.

    On construction the wrapped model's ``diffusion_model`` is patched to
    return a list of features and, when ``use_attn`` is set, its
    cross-attention layers are patched to report attention to an
    ``AttentionStore``.

    Args:
        unet: Model to wrap; it is called as ``unet(x, timesteps, id_router=..., task_embedding=...)``.
        use_attn (bool): Collect attention maps and concatenate them to the features.
        base_size (int): Used to derive the three attention-map side lengths
            (``base_size // 32``, ``// 16``, ``// 8``).
        max_attn_size (int or None): Forwarded to ``AttentionStore`` as ``max_size``.
        attn_selector (str): ``'+'``-separated store keys whose maps are used.
    """

    def __init__(self, unet, use_attn=True, base_size=512, max_attn_size=None, attn_selector='up_cross+down_cross') -> None:
        super().__init__()
        self.unet = unet
        self.attention_store = AttentionStore(base_size=base_size // 8, max_size=max_attn_size)
        self.size16, self.size32, self.size64 = base_size // 32, base_size // 16, base_size // 8
        self.use_attn = use_attn

        if self.use_attn:
            register_attention_control(self.unet, self.attention_store)
        register_hier_output(self.unet)
        self.attn_selector = attn_selector.split('+')

    def forward(self, x, timesteps, id_router = None, task_embedding = None):
        """Run the UNet and return its feature list in reversed order."""
        if self.use_attn:
            # Start from an empty store so only this pass's maps are used.
            self.attention_store.reset()

        out_list = self.unet(x, timesteps, id_router = id_router, task_embedding = task_embedding)

        if self.use_attn:
            collected = self.attention_store.get_average_attention()
            attn16, attn32, attn64 = self.process_attn(collected)
            for idx, attn in ((1, attn16), (2, attn32)):
                out_list[idx] = torch.cat([out_list[idx], attn], dim=1)
            if attn64 is not None:
                out_list[3] = torch.cat([out_list[3], attn64], dim=1)

        return out_list[::-1]

    def process_attn(self, avg_attn):
        """Group the selected attention maps by side length and average each group.

        Returns:
            ``(attn16, attn32, attn64)``; ``attn64`` is ``None`` when no map of
            that size was collected.
        """
        by_size = {self.size16: [], self.size32: [], self.size64: []}
        for selector in self.attn_selector:
            for attn_map in avg_attn[selector]:
                # The token axis is a flattened square; recover its side length.
                side = int(math.sqrt(attn_map.shape[1]))
                by_size[side].append(rearrange(attn_map, 'b (h w) c -> b c h w', h=side))

        attn16 = torch.stack(by_size[self.size16]).mean(0)
        attn32 = torch.stack(by_size[self.size32]).mean(0)
        largest = by_size[self.size64]
        attn64 = torch.stack(largest).mean(0) if len(largest) > 0 else None
        return attn16, attn32, attn64


class UniDense(nn.Module):
    """Dense-prediction model built on a Stable Diffusion UNet and autoencoder.

    An input image is encoded to latents by the autoencoder, passed through the
    wrapped UNet (routed by a task id), fused by ``SimpleDecoding`` into a
    4-channel map, decoded by the autoencoder and mapped to the task's output
    channels by a task-specific convolution.

    Args:
        replica_factor: Stored on the module; not otherwise used in this class.
        list_task (Sequence[str]): Training task names; one head is created per task.
        args: Namespace providing ``pathdir_model`` (checkpoint root or ``None``)
            and ``img_types_test`` (extra task names whose heads are appended
            after those of ``list_task``).
    """

    def __init__(self, replica_factor, list_task, args):
        super().__init__()

        self.replica_factor = replica_factor
        self.list_task = list_task

        # self.total_channels = output_dim
        self.total_channels = 1

        self.moe_type = 'normal'
        self.ismoe = True

        pathdir_model = args.pathdir_model

        # self.t_noise = args.t_noise
        # A negative value disables noising in the forward passes.
        self.t_noise = -1

        # config of stable diffusion model
        config = OmegaConf.load('./args_SD_MoE.yaml')

        config.model.params.first_stage_config.params.ddconfig.out_ch = self.total_channels

        if pathdir_model is None:
            config.model.params.ckpt_path = '../checkpoints/v1-5-pruned-emaonly.ckpt'
        else:
            config.model.params.ckpt_path = f'{pathdir_model}/stable-diffusion-v1-5/v1-5-pruned-emaonly.ckpt'

        self.scale_factor = config.model.params.scale_factor

        sd_model = instantiate_from_config(config.model) # If use self.sd_model, the performance will drop. The reason is unknown.

        # Noise-schedule tables needed by q_sample.
        self.register_buffer('sqrt_alphas_cumprod', sd_model.sqrt_alphas_cumprod)
        self.register_buffer('sqrt_one_minus_alphas_cumprod', sd_model.sqrt_one_minus_alphas_cumprod)

        self.autoencoder_KL = sd_model.first_stage_model
        self.unet = UNetWrapper(sd_model.model, use_attn=False)

        # Drop parts that this model never calls.
        del sd_model.cond_stage_model
        del self.unet.unet.diffusion_model.out

        # Freeze the autoencoder except for its decoder output convolution.
        for name, param in self.autoencoder_KL.named_parameters():
            param.requires_grad = 'decoder.conv_out' in name

        self.aggregator = SimpleDecoding(out_channels = 4, dims = [320,640,1280,1280])

        # Output channels of each task head.
        dict_channel = {
            'segment_semantic': 12,
            'normal': 3,
            'depth_euclidean': 5,
            'depth_zbuffer': 1,
            'edge_texture': 5,
            'edge_occlusion': 5,
            'keypoints2d': 1,
            'keypoints3d': 1,
            'principal_curvature': 2,
            'reshading': 1,
        }

        def make_head(out_channels):
            # All heads read the 128-channel decoder features.
            return torch.nn.Conv2d(128, out_channels, kernel_size=3, stride=1, padding=1)

        # Heads for the training tasks first, then for the test tasks.
        self.conv_layer_task_specific = nn.ModuleList(
            [make_head(dict_channel[task]) for task in list_task]
        )
        for task in args.img_types_test:
            self.conv_layer_task_specific.append(make_head(dict_channel[task]))

        self.conv_layer_meta = make_head(1)

        self.task_extractor = TaskEmbeddingExtractor()

    def init_conv_layer(self, id_task):
        """Initialise head ``id_task``'s weight by tiling the meta head's single filter.

        Only the weight is copied; the head's bias is left as it is.
        """
        c = self.conv_layer_task_specific[id_task].weight.shape[0]
        self.conv_layer_task_specific[id_task].weight.data.copy_(self.conv_layer_meta.weight.repeat(c, 1, 1, 1))

    def q_sample(self, x_start, t, noise=None):
        """Mix ``x_start`` with ``noise`` using the schedule tables at timestep ``t``.

        Gaussian noise shaped like ``x_start`` is drawn when ``noise`` is ``None``.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

    @staticmethod
    def _accumulate_aux_loss(layers, z_loss):
        """Add the auxiliary losses of every MoE transformer in ``layers`` to ``z_loss``.

        Each loss is fetched with ``get_aux_loss_and_clear()`` and added to the
        running total in encounter order.
        """
        for layer in layers:
            if not isinstance(layer, MoE_Transformer):
                continue
            for blk in layer.moe_transformer_blocks:
                if hasattr(blk.attn1, 'num_experts'):
                    z_loss = z_loss + blk.attn1.q_proj.get_aux_loss_and_clear()
                if hasattr(blk.ff, 'num_experts'):
                    z_loss = z_loss + blk.ff.get_aux_loss_and_clear()
        return z_loss

    def get_zloss(self):
        """Sum the auxiliary losses over the UNet's input, middle and output blocks."""
        unet = self.unet.unet.diffusion_model
        z_loss = 0

        # downsampling
        for module in unet.input_blocks:
            z_loss = self._accumulate_aux_loss(module, z_loss)

        # middle
        z_loss = self._accumulate_aux_loss(unet.middle_block, z_loss)

        # upsampling
        for module in unet.output_blocks:
            z_loss = self._accumulate_aux_loss(module, z_loss)

        return z_loss

    def _encode_latents(self, x):
        """Encode a batch of images to scaled latents and build the timestep tensor.

        ``x`` is mapped to [-1, 1] via ``x * 2 - 1`` before encoding. Noise is
        added only when ``self.t_noise >= 0``.

        Returns:
            ``(latents, t)`` where ``t`` holds ``self.t_noise + 1`` for every sample.
        """
        b, _, _, _ = x.shape
        x = x * 2.0 - 1.0  # normalize to [-1, 1]

        # The encoder runs without gradient tracking.
        with torch.no_grad():
            latents = self.scale_factor * self.autoencoder_KL.encode(x).mode()

        if self.t_noise >= 0:
            latents = self.q_sample(latents, torch.tensor([self.t_noise] * b).to(x.device))

        t = torch.full((b,), self.t_noise + 1, device=x.device).long()
        return latents, t

    def _aggregate(self, conv_feats):
        """Fuse the four UNet feature maps (shallowest first) into a 4-channel map."""
        # Original author's shape notes (presumably for 256 x 256 inputs):
        # (BS, 320, 32, 32), (BS, 640, 16, 16), (BS, 1280, 8, 8), (BS, 1280, 4, 4)
        x_c1, x_c2, x_c3, x_c4 = conv_feats[0], conv_feats[1], conv_feats[2], conv_feats[3]
        return self.aggregator(x_c4, x_c3, x_c2, x_c1)

    def _decode_task(self, x, id_task):
        """Decode an aggregated map and apply head ``id_task``; the output lies in (0, 1)."""
        x = 1. / self.scale_factor * x
        x = self.autoencoder_KL.decode(x)  # heads expect 128 channels here
        x = self.conv_layer_task_specific[id_task](x)
        x = torch.tanh(x)
        return (x + 1.0) / 2.0

    def forward_features(self, x):
        """Return one aggregated feature map per task in ``self.list_task``.

        The image is encoded once; the UNet is then run once per task with the
        task index as router id.
        """
        latents, t = self._encode_latents(x)

        output = []
        for id_task in tqdm(range(len(self.list_task))):
            conv_feats = self.unet(latents, t, id_router = id_task)
            output.append(self._aggregate(conv_feats))

        return output

    def forward(self, x, id_task = None, image_support = None, label_support = None):
        """Predict dense outputs.

        With ``id_task`` given, a task embedding is extracted from the support
        pair and a single prediction for head ``id_task`` is returned as
        ``(prediction, None)``. Otherwise predictions for every task in
        ``self.list_task`` are returned as ``(list_of_predictions, z_loss)``.
        """
        if id_task is not None:
            # Tuple of three embeddings (320-, 640- and 1280-dimensional).
            task_embedding = self.task_extractor(image_support, label_support)

            latents, t = self._encode_latents(x)

            # NOTE(review): the router id is fixed to 8 on this path regardless
            # of id_task - confirm this is the intended routing slot.
            conv_feats = self.unet(latents, t, id_router = 8, task_embedding = task_embedding)

            x = self._aggregate(conv_feats)
            return self._decode_task(x, id_task), None

        output = self.forward_features(x)

        for id_task, the_type in tqdm(enumerate(self.list_task)):
            output[id_task] = self._decode_task(output[id_task], id_task)

        z_loss = self.get_zloss()

        return output, z_loss


class MTVisionTransformer(timm.models.vision_transformer.VisionTransformer):
    """Multi-task Vision Transformer with one output head per non-RGB image type.

    The backbone is shared between tasks. A learned per-task embedding is added
    to the token sequence before every transformer block, and the output is
    produced by the head that matches the requested task.
    """

    # Number of output channels of the dense head for each supported image type.
    _TYPE_TO_CHANNEL = {'depth_euclidean':1, 'depth_zbuffer':1, 'edge_occlusion':1, 'edge_texture':1, 'keypoints2d':1, 'keypoints3d':1, 'normal':3, 'principal_curvature':2,  'reshading':3, 'rgb':3, 'segment_semantic':18, 'segment_unsup2d':1, 'segment_unsup25d':1}

    def __init__(self, img_types, embed_dim=768, global_pool=True, **kwargs):
        """Build the backbone and one head per entry of ``img_types`` except 'rgb'.

        Args:
            img_types: Iterable of image-type names. 'rgb' is dropped; names
                containing 'class' get a classification head, every other name
                must be a key of ``_TYPE_TO_CHANNEL``.
            embed_dim: Token embedding width, forwarded to the base class.
            global_pool: Must be truthy; it is not used otherwise.
            **kwargs: Forwarded to the base class. Must contain 'norm_layer',
                which is also used to build ``self.fc_norm``.

        Raises:
            AssertionError: If ``global_pool`` is falsy or the patch embedding
                is not configured for 224x224 images.
            KeyError: If 'norm_layer' is missing from ``kwargs`` or a dense
                image type is not in ``_TYPE_TO_CHANNEL``.
        """
        super(MTVisionTransformer, self).__init__(embed_dim=embed_dim, **kwargs)
        self.taskGating = False
        self.ismoe = False

        self.moe_type = 'normal'

        self.img_types = [type_ for type_ in img_types if type_ != 'rgb']
        assert global_pool, "MTVisionTransformer requires global_pool to be enabled"
        # The base-class classification head is replaced by the task heads.
        del self.head
        norm_layer = kwargs['norm_layer']
        self.fc_norm = norm_layer(embed_dim)

        image_height, image_width = self.patch_embed.img_size
        patch_height, patch_width = self.patch_embed.patch_size
        assert image_height == 224 and image_width == 224
        # Number of patches along each spatial axis.
        grid_h = image_height // patch_height
        grid_w = image_width // patch_width

        # Create the task heads in the same order as self.img_types so that a
        # task's index selects both its head and its task embedding.
        heads = []
        for img_type in self.img_types:
            if 'class' in img_type:
                class_num = 1000 if img_type == 'class_object' else 365
                # Applied to the pooled feature produced by forward_features
                # (mean over the non-cls tokens, then fc_norm).
                heads.append(
                    nn.Sequential(
                        nn.LayerNorm(embed_dim),
                        nn.Linear(embed_dim, class_num)
                    )
                )
            else:
                channel = self._TYPE_TO_CHANNEL[img_type]
                # Project every patch token to a patch of pixels and stitch the
                # patches back into a channel-last 'b H W c' map.
                heads.append(
                    nn.Sequential(
                        Rearrange('b (h w) d -> (b h w) d', h = grid_h, w = grid_w),
                        nn.Linear(embed_dim, patch_height * patch_width * channel),
                        Rearrange('(b h w) (j k c) -> b (h j) (w k) c', h = grid_h, w = grid_w, j=patch_height, k=patch_width, c=channel),
                    )
                )
        self.task_heads = nn.ModuleList(heads)

        self.task_embedding = nn.Parameter(torch.randn(1, len(self.img_types), embed_dim))

        self.apply(self._init_weights)

    def forward_features(self, x, task_rank, task):
        """Encode ``x`` with the shared backbone, conditioned on one task.

        Args:
            x: Input batch passed to ``self.patch_embed``.
            task_rank: Index of the task in ``self.img_types``; selects the
                task embedding added before every block.
            task: Task name; names containing 'class' are mean-pooled.

        Returns:
            ``(features, 0)``. For 'class' tasks ``features`` is the mean of
            the non-cls tokens passed through ``self.fc_norm``; otherwise it is
            the normalised token sequence without the cls token. The second
            element is a constant ``0`` standing in for a z-loss.
        """
        B = x.shape[0]
        x = self.patch_embed(x)

        # Prepend the cls token, then add positional embeddings.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        x = self.pos_drop(x)

        # The slice is loop-invariant; it broadcasts over batch and tokens.
        task_token = self.task_embedding[:, task_rank:task_rank + 1, :]
        for blk in self.blocks:
            x = blk(x + task_token)

        if 'class' in task:
            x = x[:, 1:, :].mean(dim=1)
            x = self.fc_norm(x)
        else:
            x = self.norm(x)
            x = x[:, 1:, :]

        return x, 0

    def forward(self, x, task, get_flop=False):
        """Run the backbone and the head that belongs to ``task``.

        Args:
            x: Input batch.
            task: Name of the task; must be an entry of ``self.img_types``
                ('rgb' is never one of them).
            get_flop: When true, return only the prediction.

        Returns:
            The prediction alone when ``get_flop`` is true, otherwise
            ``(prediction, z_loss)`` with ``z_loss`` as returned by
            ``forward_features``.

        Raises:
            AssertionError: If ``task`` is not in ``self.img_types``.
        """
        try:
            task_rank = self.img_types.index(task)
        except ValueError:
            # Raised explicitly (same type as the former assert) so the check
            # survives `python -O` instead of silently selecting the last head.
            raise AssertionError(
                f"unknown task {task!r}; expected one of {self.img_types}"
            ) from None

        x, z_loss = self.forward_features(x, task_rank, task)
        x = self.task_heads[task_rank](x)

        if get_flop:
            return x
        return x, z_loss
